############################################################
# Summary: This script reproduces the functional tests

# Output:
# - Figure 4: Confusion Matrices
  
# SVM predictions on the test sentences. The pipeline below relies on the
# columns `dataset`, `expected_label` and one `pred_label_seed<N>` column per
# seed.
df_pred_svm <- read_csv("data/results/test_sentences_prediction_svm.csv")

# Confusion-matrix counts for the SVM models whose `dataset` value starts with
# "strong_". Result columns: Group (training dataset), seed, Actual, Predicted,
# Count and Color ("green" when Actual equals Predicted, otherwise "red").
# NOTE(review): no seed filter is applied, so counts are produced per seed. If
# the input holds several seed columns, a plot faceted only by Group will draw
# one tile and one label per seed in the same cell - confirm this is intended.
conf_matrix_svm_explicit <- df_pred_svm |>
  pivot_longer(
    cols = starts_with("pred_label_seed"),
    names_to = "seed",
    values_to = "pred_label"
  ) |>
  # Keep only the numeric part of the former column name.
  mutate(seed = as.integer(gsub("pred_label_seed", "", seed))) |>
  separate(dataset, into = c("label_type", "training_dataset"), sep = "_") |>
  filter(label_type == "strong") |>
  count(training_dataset, seed, expected_label, pred_label) |>
  rename(
    Group = training_dataset,
    Actual = expected_label,
    Predicted = pred_label,
    Count = n
  ) |>
  # Labels are compared on their first letter only.
  mutate(
    Actual = substr(Actual, 1, 1),
    Predicted = substr(Predicted, 1, 1)
  ) |>
  mutate(Color = ifelse(Actual == Predicted, "green", "red")) |>
  # Swap the one-letter codes for the letters shown on the axes.
  mutate(across(
    c(Actual, Predicted),
    \(code) case_when(
      code == "d" ~ "a",
      code == "e" ~ "c",
      code == "n" ~ "o",
      TRUE ~ code
    )
  ))
  
  # Confusion-matrix heat map for the SVM / strong-label results: one facet per
  # training dataset (Group), tile colour taken from `Color`, tile opacity
  # scaled by the cell count.
  testsentences_svm_explicit <- ggplot(
    conf_matrix_svm_explicit,
    aes(x = Predicted, y = Actual, fill = Color, alpha = Count)
  ) +
    geom_tile(color = "white") +
    coord_equal() +
    # A constant alpha keeps the count labels opaque on faint tiles.
    geom_text(aes(label = Count), vjust = .5, fontface = "bold", alpha = 1) +
    # Named values tie each `Color` value to its own colour. An unnamed vector
    # is matched by position against the values present in the data, so a table
    # without any "green" row would have its "red" tiles painted green.
    scale_fill_manual(values = c(green = "green", red = "red")) +
    scale_alpha(range = c(0.4, 1)) +
    facet_wrap(~ Group, nrow = 2) +
    guides(fill = "none") +
    guides(label = "none") +
    theme_minimal() +
    labs(title = "SVM") +
    theme(
      plot.title = element_text(size = 10, face = "bold", hjust = 0.5),
      legend.position = "none"
    )
  
  
  # Confusion-matrix counts for the SVM models whose `dataset` value starts
  # with "combined_". Result columns: Group (training dataset), seed, Actual,
  # Predicted, Count and Color ("green" when Actual equals Predicted, otherwise
  # "red").
  # NOTE(review): no seed filter is applied, so counts are produced per seed.
  # If the input holds several seed columns, a plot faceted only by Group will
  # draw one tile and one label per seed in the same cell - confirm this is
  # intended.
  conf_matrix_svm_combined <- df_pred_svm |>
    pivot_longer(
      cols = starts_with("pred_label_seed"),
      names_to = "seed",
      values_to = "pred_label"
    ) |>
    # Keep only the numeric part of the former column name.
    mutate(seed = as.integer(gsub("pred_label_seed", "", seed))) |>
    separate(dataset, into = c("label_type", "training_dataset"), sep = "_") |>
    filter(label_type == "combined") |>
    count(training_dataset, seed, expected_label, pred_label) |>
    rename(
      Group = training_dataset,
      Actual = expected_label,
      Predicted = pred_label,
      Count = n
    ) |>
    # Labels are compared on their first letter only.
    mutate(
      Actual = substr(Actual, 1, 1),
      Predicted = substr(Predicted, 1, 1)
    ) |>
    mutate(Color = ifelse(Actual == Predicted, "green", "red")) |>
    # Swap the one-letter codes for the letters shown on the axes.
    mutate(across(
      c(Actual, Predicted),
      \(code) case_when(
        code == "d" ~ "a",
        code == "e" ~ "c",
        code == "n" ~ "o",
        TRUE ~ code
      )
    ))
  # Confusion-matrix heat map for the SVM / combined-label results: one facet
  # per training dataset (Group), tile colour taken from `Color`, tile opacity
  # scaled by the cell count.
  testsentences_svm_combined <- ggplot(
    conf_matrix_svm_combined,
    aes(x = Predicted, y = Actual, fill = Color, alpha = Count)
  ) +
    geom_tile(color = "white") +
    coord_equal() +
    # A constant alpha keeps the count labels opaque on faint tiles.
    geom_text(aes(label = Count), vjust = .5, fontface = "bold", alpha = 1) +
    # Named values tie each `Color` value to its own colour. An unnamed vector
    # is matched by position against the values present in the data, so a table
    # without any "green" row would have its "red" tiles painted green.
    scale_fill_manual(values = c(green = "green", red = "red")) +
    scale_alpha(range = c(0.4, 1)) +
    facet_wrap(~ Group, nrow = 2) +
    guides(fill = "none") +
    guides(label = "none") +
    theme_minimal() +
    labs(title = "SVM") +
    theme(
      plot.title = element_text(size = 10, face = "bold", hjust = 0.5),
      legend.position = "none"
    )
  
  #############
  # Predictions stored in the "semi" results file, restricted to seed 42.
  # `pred_seed` holds integer class ids: {'none': 0, 'emp': 1, 'dur': 2};
  # any other id yields NA in `pred_label`.
  df_pred_semi <- read_csv("data/results/test_sentences_predictions_semi.csv")
  df_pred_semi <- df_pred_semi |>
    filter(seed == 42) |>
    mutate(
      pred_label = case_when(
        pred_seed == 0 ~ "none",
        pred_seed == 1 ~ "emp",
        pred_seed == 2 ~ "dur"
      )
    )
  
  # Confusion-matrix counts for the seed-42 predictions whose `dataset_label`
  # starts with "strong_". Result columns: Group (training dataset), Actual,
  # Predicted, Count and Color ("green" when Actual equals Predicted, otherwise
  # "red").
  conf_matrix_semi_explicit <- df_pred_semi |>
    separate(
      dataset_label,
      into = c("label_type", "training_dataset"),
      sep = "_"
    ) |>
    filter(label_type == "strong") |>
    count(training_dataset, expected_label, pred_label) |>
    rename(
      Group = training_dataset,
      Actual = expected_label,
      Predicted = pred_label,
      Count = n
    ) |>
    # Labels are compared on their first letter only.
    mutate(
      Actual = substr(Actual, 1, 1),
      Predicted = substr(Predicted, 1, 1)
    ) |>
    mutate(Color = ifelse(Actual == Predicted, "green", "red")) |>
    # Swap the one-letter codes for the letters shown on the axes.
    mutate(across(
      c(Actual, Predicted),
      \(code) case_when(
        code == "d" ~ "a",
        code == "e" ~ "c",
        code == "n" ~ "o",
        TRUE ~ code
      )
    )) |>
    # Hard-coded zero-count cell, coloured white, for the "balanced" group.
    # NOTE(review): presumably pads a cell that is absent from the counts so
    # the facet shows a full grid - confirm against the data.
    add_row(
      Group = "balanced",
      Actual = "o",
      Predicted = "a",
      Count = 0,
      Color = "white"
    )
  
  # Confusion-matrix heat map for the XLM-RoBERTa / strong-label results: one
  # facet per training dataset (Group), tile colour taken from `Color`, tile
  # opacity scaled by the cell count.
  testsentences_semi_explicit <- ggplot(
    conf_matrix_semi_explicit,
    aes(x = Predicted, y = Actual, fill = Color, alpha = Count)
  ) +
    geom_tile(color = "white") +
    coord_equal() +
    # A constant alpha keeps the count labels opaque on faint tiles.
    geom_text(aes(label = Count), vjust = .5, fontface = "bold", alpha = 1) +
    # Named values tie each `Color` value to its own colour. An unnamed vector
    # is matched by position against the values present in the data, so a
    # missing category would shift every remaining colour by one slot.
    scale_fill_manual(
      values = c(green = "green", red = "red", white = "white")
    ) +
    scale_alpha(range = c(0.4, 1)) +
    facet_wrap(~ Group, nrow = 2) +
    guides(fill = "none") +
    guides(label = "none") +
    theme_minimal() +
    labs(title = "XLM-RoBERTa") +
    theme(
      plot.title = element_text(size = 10, face = "bold", hjust = 0.5),
      legend.position = "none"
    )
  
  # Confusion-matrix counts for the seed-42 predictions whose `dataset_label`
  # starts with "combined_". Result columns: Group (training dataset), Actual,
  # Predicted, Count and Color ("green" when Actual equals Predicted, otherwise
  # "red").
  conf_matrix_semi_combined <- df_pred_semi |>
    separate(
      dataset_label,
      into = c("label_type", "training_dataset"),
      sep = "_"
    ) |>
    filter(label_type == "combined") |>
    count(training_dataset, expected_label, pred_label) |>
    rename(
      Group = training_dataset,
      Actual = expected_label,
      Predicted = pred_label,
      Count = n
    ) |>
    # Labels are compared on their first letter only.
    mutate(
      Actual = substr(Actual, 1, 1),
      Predicted = substr(Predicted, 1, 1)
    ) |>
    mutate(Color = ifelse(Actual == Predicted, "green", "red")) |>
    # Swap the one-letter codes for the letters shown on the axes.
    mutate(across(
      c(Actual, Predicted),
      \(code) case_when(
        code == "d" ~ "a",
        code == "e" ~ "c",
        code == "n" ~ "o",
        TRUE ~ code
      )
    ))
  
  # Confusion-matrix heat map for the XLM-RoBERTa / combined-label results: one
  # facet per training dataset (Group), tile colour taken from `Color`, tile
  # opacity scaled by the cell count.
  testsentences_semi_combined <- ggplot(
    conf_matrix_semi_combined,
    aes(x = Predicted, y = Actual, fill = Color, alpha = Count)
  ) +
    geom_tile(color = "white") +
    coord_equal() +
    # A constant alpha keeps the count labels opaque on faint tiles.
    geom_text(aes(label = Count), vjust = .5, fontface = "bold", alpha = 1) +
    # Named values tie each `Color` value to its own colour. An unnamed vector
    # is matched by position against the values present in the data, so a table
    # without any "green" row would have its "red" tiles painted green.
    scale_fill_manual(values = c(green = "green", red = "red")) +
    scale_alpha(range = c(0.4, 1)) +
    facet_wrap(~ Group, nrow = 2) +
    guides(fill = "none") +
    guides(label = "none") +
    theme_minimal() +
    labs(title = "XLM-RoBERTa") +
    theme(
      plot.title = element_text(size = 10, face = "bold", hjust = 0.5),
      legend.position = "none"
    )
  
  ###
  
  
  # Keep the text that precedes the first comma of each element of `x` and
  # strip surrounding whitespace. Elements that do not start with at least one
  # non-comma character have no match and come back as NA.
  pick_first <- function(x) {
    leading_field <- str_extract(x, "^[^,]+")
    str_trim(leading_field)
  }
  
  # NOTE(review): these two scalars are not used anywhere in the visible part
  # of the script - presumably defaults for interactive debugging; confirm.
  method_input <- "gpt"
  label_type_input <- "explicit"

  # Long table of prompting predictions: one row per test sentence and
  # prediction column. The column name is split into method, label type and
  # prompting strategy; every label type other than "explicit" is relabelled
  # "combined"; `dataset_label` keeps only the first comma-separated entry of
  # the predicted value.
  # NOTE(review): the join has no `by`, so it uses every column the two files
  # share - confirm that this is the intended key.
  df_pred_prompting <-
    read_csv("data/results/test_sentences_final_prompting_predictions.csv") |>
    left_join(read_csv("data/test_sentences/test_sentences.csv")) |>
    pivot_longer(
      cols = gpt_explicit_zero_prediction:deepseek_implicit_few_prediction
    ) |>
    separate(
      name,
      into = c("method", "label_type", "prompting_strategy", "xx"),
      sep = "_"
    ) |>
    mutate(
      label_type = if_else(label_type == "explicit", "explicit", "combined"),
      dataset_label = pick_first(value)
    )

  # Sanity check: number of rows per method and label type.
  df_pred_prompting |> count(method, label_type)
  
  
  # Build one plot per (method, label type) pair and store it under the key
  # "<method>_<label_type>".
  # NOTE(review): create_plot() is not defined in the visible part of this
  # script; it has to exist before this loop runs - confirm where it lives.
  methods <- c("gpt", "llama", "deepseek")
  label_types <- c("explicit", "combined")

  plot_list <- list()

  for (m in methods) {
    for (l in label_types) {
      # Progress output: which combination is being drawn.
      print(m)
      print(l)
      plot_key <- paste(m, l, sep = "_")
      plot_list[[plot_key]] <- create_plot(m, l)
    }
  }
  
  
  # Shared theme for every confusion-matrix panel: no backgrounds, grid lines,
  # axis titles or legend, a thin grey panel border and a small bold centred
  # title.
  my_confmat_theme <- theme_minimal(base_size = 12) +
    theme(
      panel.background  = element_blank(),
      plot.background   = element_blank(),
      panel.spacing = unit(-3, "lines"),  # reduce this to tighten space between panels

      panel.grid.major  = element_blank(),
      panel.grid.minor  = element_blank(),
      # `linewidth` replaces the `size` argument, which is deprecated for
      # rectangle and line elements since ggplot2 3.4.0.
      panel.border      = element_rect(color = "grey80", fill = NA, linewidth = 0.1),
      axis.title.x      = element_blank(),
      axis.title.y      = element_blank(),
      legend.position = "none",
      plot.title = element_text(size = 9, face = "bold", hjust = 0.5, vjust = 0),
      plot.margin = margin(0, 0, 0, 0)
    )
  
  
  # Strong Labels row: the five model panels side by side, each restyled with
  # the shared confusion-matrix theme and forced to square cells. The
  # prompting-model panels receive their display titles here.
  plot <- ggarrange(
    plotlist = list(
      testsentences_svm_explicit + my_confmat_theme + coord_equal(),
      testsentences_semi_explicit + my_confmat_theme + coord_equal(),
      plot_list[["llama_explicit"]] + my_confmat_theme + coord_equal() +
        labs(title = "Llama 3-8B"),
      plot_list[["gpt_explicit"]] + my_confmat_theme + coord_equal() +
        labs(title = "GPT-4o"),
      plot_list[["deepseek_explicit"]] + my_confmat_theme + coord_equal() +
        labs(title = "Deepseek-V3")
    ),
    nrow = 1,
    align = "hv"
  )

  # Row heading placed above the arranged panels.
  p1 <- annotate_figure(
    plot,
    top = text_grob(
      "Strong Labels",
      color = "black",
      face = "bold",
      size = 12,
      vjust = 0.1
    )
  )
  
  # Combined Labels row: the five model panels side by side, each restyled with
  # the shared confusion-matrix theme and forced to square cells. The
  # prompting-model panels receive their display titles here.
  plot2 <- ggarrange(
    testsentences_svm_combined      + my_confmat_theme + coord_equal(),
    testsentences_semi_combined     + my_confmat_theme + coord_equal(),
    # Title without the stray trailing space, matching the strong-labels row.
    plot_list[["llama_combined"]]   + my_confmat_theme + coord_equal() + labs(title = "Llama 3-8B"),
    plot_list[["gpt_combined"]]     + my_confmat_theme + coord_equal() + labs(title = "GPT-4o"),
    plot_list[["deepseek_combined"]] + my_confmat_theme + coord_equal() + labs(title = "Deepseek-V3"),
    nrow = 1, align = "hv"
  )

  # Row heading placed above the arranged panels.
  p2 <- annotate_figure(
    plot2,
    top = text_grob("Combined Labels", color = "black", face = "bold", size = 12, vjust = 0.1)
  )
  
  # Empty panel that adds vertical space between the two rows.
  spacer <- ggplot() + theme_void()

  # Stack the strong-labels row, the spacer and the combined-labels row.
  final_plot <- ggarrange(p1, spacer, p2, nrow = 3, heights = c(1, 0.1, 1))

  # Display
  print(final_plot)

  # Save at high resolution. The target directory is created first so that the
  # script also runs in a checkout where output/ does not exist yet.
  dir.create("output", showWarnings = FALSE, recursive = TRUE)
  ggsave(
    "output/Figure_4_functional_tests.png",
    final_plot,
    width = 10,
    height = 6,
    dpi = 600,
    units = "in"
  )

  